{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# PaddleHub之《青春有你2》进行二分类"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 一、任务简介\n",
    "\n",
    "图像分类是计算机视觉的重要领域，它的目标是将图像分类到预定义的标签。近期，许多研究者提出很多不同种类的神经网络，并且极大的提升了分类算法的性能。本文以自己创建的数据集：青春有你2中选手识别为例子，介绍如何使用PaddleHub进行图像分类任务。\n",
    "\n",
    "<div  align=\"center\">   \n",
    "<img src=\"https://ai-studio-static-online.cdn.bcebos.com/dbaf5ba718f749f0836c342fa67f6d7954cb89f3796b4c8b981a03a1635f2fbe\" width = \"350\" height = \"250\" align=center />\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CPU_NUM=1\n"
     ]
    }
   ],
   "source": [
    "#CPU环境启动请务必执行该指令\n",
    "%set_env CPU_NUM=1 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: paddlehub==1.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (1.6.0)\n",
      "Requirement already satisfied: chardet==3.0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (3.0.4)\n",
      "Requirement already satisfied: nltk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (3.4.5)\n",
      "Requirement already satisfied: sentencepiece in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (0.1.85)\n",
      "Requirement already satisfied: gunicorn>=19.10.0; sys_platform != \"win32\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (20.0.4)\n",
      "Requirement already satisfied: tb-paddle in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (0.3.6)\n",
      "Requirement already satisfied: flake8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (3.7.9)\n",
      "Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (6.2.0)\n",
      "Requirement already satisfied: colorlog in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (4.1.0)\n",
      "Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (3.10.0)\n",
      "Requirement already satisfied: cma==2.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (2.7.0)\n",
      "Requirement already satisfied: yapf==0.26.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (0.26.0)\n",
      "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (5.1.2)\n",
      "Requirement already satisfied: pandas; python_version >= \"3\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (0.23.4)\n",
      "Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (2.22.0)\n",
      "Requirement already satisfied: tensorboard>=1.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (2.1.0)\n",
      "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (4.1.1.26)\n",
      "Requirement already satisfied: flask>=1.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (1.1.1)\n",
      "Requirement already satisfied: numpy; python_version >= \"3\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (1.16.4)\n",
      "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (1.21.0)\n",
      "Requirement already satisfied: six>=1.10.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlehub==1.6.0) (1.12.0)\n",
      "Requirement already satisfied: setuptools>=3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from gunicorn>=19.10.0; sys_platform != \"win32\"->paddlehub==1.6.0) (41.4.0)\n",
      "Requirement already satisfied: moviepy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tb-paddle->paddlehub==1.6.0) (1.0.1)\n",
      "Requirement already satisfied: pycodestyle<2.6.0,>=2.5.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8->paddlehub==1.6.0) (2.5.0)\n",
      "Requirement already satisfied: entrypoints<0.4.0,>=0.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8->paddlehub==1.6.0) (0.3)\n",
      "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8->paddlehub==1.6.0) (0.6.1)\n",
      "Requirement already satisfied: pyflakes<2.2.0,>=2.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8->paddlehub==1.6.0) (2.1.1)\n",
      "Requirement already satisfied: python-dateutil>=2.5.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas; python_version >= \"3\"->paddlehub==1.6.0) (2.8.0)\n",
      "Requirement already satisfied: pytz>=2011k in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas; python_version >= \"3\"->paddlehub==1.6.0) (2019.3)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddlehub==1.6.0) (1.25.6)\n",
      "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddlehub==1.6.0) (2.8)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->paddlehub==1.6.0) (2019.9.11)\n",
      "Requirement already satisfied: absl-py>=0.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (0.8.1)\n",
      "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (0.16.0)\n",
      "Requirement already satisfied: google-auth<2,>=1.6.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (1.10.0)\n",
      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (0.4.1)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (3.1.1)\n",
      "Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (1.26.0)\n",
      "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from tensorboard>=1.15->paddlehub==1.6.0) (0.33.6)\n",
      "Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==1.6.0) (7.0)\n",
      "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==1.6.0) (2.10.1)\n",
      "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.0->paddlehub==1.6.0) (1.1.0)\n",
      "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (0.10.0)\n",
      "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (1.4.10)\n",
      "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (2.0.1)\n",
      "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (1.3.4)\n",
      "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (16.7.9)\n",
      "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (0.23)\n",
      "Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->paddlehub==1.6.0) (1.3.0)\n",
      "Requirement already satisfied: imageio-ffmpeg>=0.2.0; python_version >= \"3.4\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from moviepy->tb-paddle->paddlehub==1.6.0) (0.3.0)\n",
      "Requirement already satisfied: proglog<=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from moviepy->tb-paddle->paddlehub==1.6.0) (0.1.9)\n",
      "Requirement already satisfied: decorator<5.0,>=4.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from moviepy->tb-paddle->paddlehub==1.6.0) (4.4.0)\n",
      "Requirement already satisfied: imageio<3.0,>=2.5; python_version >= \"3.4\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from moviepy->tb-paddle->paddlehub==1.6.0) (2.6.1)\n",
      "Requirement already satisfied: tqdm<5.0,>=4.11.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from moviepy->tb-paddle->paddlehub==1.6.0) (4.36.1)\n",
      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.15->paddlehub==1.6.0) (4.0.0)\n",
      "Requirement already satisfied: rsa<4.1,>=3.1.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.15->paddlehub==1.6.0) (4.0)\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.15->paddlehub==1.6.0) (0.2.7)\n",
      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.15->paddlehub==1.6.0) (1.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.0->paddlehub==1.6.0) (1.1.1)\n",
      "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata; python_version < \"3.8\"->pre-commit->paddlehub==1.6.0) (0.6.0)\n",
      "Requirement already satisfied: pyasn1>=0.1.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from rsa<4.1,>=3.1.4->google-auth<2,>=1.6.3->tensorboard>=1.15->paddlehub==1.6.0) (0.4.8)\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.15->paddlehub==1.6.0) (3.1.0)\n",
      "Requirement already satisfied: more-itertools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from zipp>=0.5->importlib-metadata; python_version < \"3.8\"->pre-commit->paddlehub==1.6.0) (7.2.0)\n"
     ]
    }
   ],
   "source": [
    "#安装paddlehup\n",
    "!pip install paddlehub==1.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 二、任务实践\n",
    "### Step1、基础工作\n",
    "\n",
    "加载数据文件\n",
    "\n",
    "导入python包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Archive:  /home/aistudio/data/data30796/data.zip\n",
      "  inflating: data/validate_list.txt  \n",
      "  inflating: __MACOSX/data/._validate_list.txt  \n",
      "  inflating: data/.DS_Store          \n",
      "  inflating: __MACOSX/data/._.DS_Store  \n",
      "   creating: data/xu/\n",
      "  inflating: __MACOSX/data/._xu      \n",
      "  inflating: data/train_list.txt     \n",
      "  inflating: __MACOSX/data/._train_list.txt  \n",
      "   creating: data/yu/\n",
      "  inflating: __MACOSX/data/._yu      \n",
      "  inflating: data/label_list.txt     \n",
      "  inflating: __MACOSX/data/._label_list.txt  \n",
      "  inflating: data/test_list.txt      \n",
      "  inflating: __MACOSX/data/._test_list.txt  \n",
      "  inflating: data/xu/.DS_Store       \n",
      "  inflating: __MACOSX/data/xu/._.DS_Store  \n",
      "  inflating: data/xu/8.jpg           \n",
      "  inflating: __MACOSX/data/xu/._8.jpg  \n",
      "  inflating: data/xu/9.jpg           \n",
      "  inflating: __MACOSX/data/xu/._9.jpg  \n",
      "  inflating: data/xu/14.jpg          \n",
      "  inflating: __MACOSX/data/xu/._14.jpg  \n",
      "  inflating: data/xu/15.jpg          \n",
      "  inflating: __MACOSX/data/xu/._15.jpg  \n",
      "  inflating: data/xu/16.jpg          \n",
      "  inflating: __MACOSX/data/xu/._16.jpg  \n",
      "  inflating: data/xu/12.jpg          \n",
      "  inflating: __MACOSX/data/xu/._12.jpg  \n",
      "  inflating: data/xu/13.jpg          \n",
      "  inflating: __MACOSX/data/xu/._13.jpg  \n",
      "  inflating: data/xu/11.jpg          \n",
      "  inflating: __MACOSX/data/xu/._11.jpg  \n",
      "  inflating: data/xu/10.jpg          \n",
      "  inflating: __MACOSX/data/xu/._10.jpg  \n",
      "  inflating: data/xu/4.jpg           \n",
      "  inflating: __MACOSX/data/xu/._4.jpg  \n",
      "  inflating: data/xu/5.jpg           \n",
      "  inflating: __MACOSX/data/xu/._5.jpg  \n",
      "  inflating: data/xu/7.jpg           \n",
      "  inflating: __MACOSX/data/xu/._7.jpg  \n",
      "  inflating: data/xu/6.jpg           \n",
      "  inflating: __MACOSX/data/xu/._6.jpg  \n",
      "  inflating: data/xu/2.jpg           \n",
      "  inflating: __MACOSX/data/xu/._2.jpg  \n",
      "  inflating: data/xu/3.jpg           \n",
      "  inflating: __MACOSX/data/xu/._3.jpg  \n",
      "  inflating: data/xu/1.jpg           \n",
      "  inflating: __MACOSX/data/xu/._1.jpg  \n",
      "  inflating: data/yu/.DS_Store       \n",
      "  inflating: __MACOSX/data/yu/._.DS_Store  \n",
      "  inflating: data/yu/8.jpg           \n",
      "  inflating: __MACOSX/data/yu/._8.jpg  \n",
      "  inflating: data/yu/9.jpg           \n",
      "  inflating: __MACOSX/data/yu/._9.jpg  \n",
      "  inflating: data/yu/14.jpg          \n",
      "  inflating: __MACOSX/data/yu/._14.jpg  \n",
      "  inflating: data/yu/15.jpg          \n",
      "  inflating: __MACOSX/data/yu/._15.jpg  \n",
      "  inflating: data/yu/16.jpg          \n",
      "  inflating: __MACOSX/data/yu/._16.jpg  \n",
      "  inflating: data/yu/12.jpg          \n",
      "  inflating: __MACOSX/data/yu/._12.jpg  \n",
      "  inflating: data/yu/13.jpg          \n",
      "  inflating: __MACOSX/data/yu/._13.jpg  \n",
      "  inflating: data/yu/11.jpg          \n",
      "  inflating: __MACOSX/data/yu/._11.jpg  \n",
      "  inflating: data/yu/10.jpg          \n",
      "  inflating: __MACOSX/data/yu/._10.jpg  \n",
      "  inflating: data/yu/4.jpg           \n",
      "  inflating: __MACOSX/data/yu/._4.jpg  \n",
      "  inflating: data/yu/5.jpg           \n",
      "  inflating: __MACOSX/data/yu/._5.jpg  \n",
      "  inflating: data/yu/7.jpg           \n",
      "  inflating: __MACOSX/data/yu/._7.jpg  \n",
      "  inflating: data/yu/6.jpg           \n",
      "  inflating: __MACOSX/data/yu/._6.jpg  \n",
      "  inflating: data/yu/2.jpg           \n",
      "  inflating: __MACOSX/data/yu/._2.jpg  \n",
      "  inflating: data/yu/3.jpg           \n",
      "  inflating: __MACOSX/data/yu/._3.jpg  \n",
      "  inflating: data/yu/1.jpg           \n",
      "  inflating: __MACOSX/data/yu/._1.jpg  \n"
     ]
    }
   ],
   "source": [
    "!unzip -o /home/aistudio/data/data30796/data.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddlehub as hub"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step2、加载预训练模型\n",
    "\n",
    "接下来我们要在PaddleHub中选择合适的预训练模型来Finetune，由于是图像分类任务，因此我们使用经典的ResNet-50作为预训练模型。PaddleHub提供了丰富的图像分类预训练模型，包括了最新的神经网络架构搜索类的PNASNet，我们推荐您尝试不同的预训练模型来获得更好的性能。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2021-03-08 16:50:09,220] [    INFO] - Installing resnet_v2_50_imagenet module\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading resnet_v2_50_imagenet\n",
      "[==================================================] 100.00%\n",
      "Uncompress /home/aistudio/.paddlehub/tmp/tmpq8so69zk/resnet_v2_50_imagenet\n",
      "[==================================================] 100.00%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2021-03-08 16:50:16,483] [    INFO] - Successfully installed resnet_v2_50_imagenet-1.0.1\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "module = hub.Module(name=\"resnet_v2_50_imagenet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step3、数据准备\n",
    "\n",
    "接着需要加载图片数据集。我们使用自定义的数据进行体验，请查看[适配自定义数据](https://github.com/PaddlePaddle/PaddleHub/wiki/PaddleHub适配自定义数据完成FineTune)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from paddlehub.dataset.base_cv_dataset import BaseCVDataset\n",
    "   \n",
    "class DemoDataset(BaseCVDataset):\t\n",
    "   def __init__(self):\t\n",
    "       # 数据集存放位置\n",
    "       \n",
    "       self.dataset_dir = \"data\"\n",
    "       super(DemoDataset, self).__init__(\n",
    "           base_path=self.dataset_dir,\n",
    "           train_list_file=\"train_list.txt\",\n",
    "           validate_list_file=\"validate_list.txt\",\n",
    "           test_list_file=\"test_list.txt\",\n",
    "           #predict_file=\"predict_list.txt\",\n",
    "           label_list_file=\"label_list.txt\",\n",
    "           )\n",
    "dataset = DemoDataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-03-08 16:50:16,859-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']\n",
      "2021-03-08 16:50:17,212-INFO: generated new fontManager\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 1000x1000 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                许佳琪                                     虞书欣\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "import matplotlib.image as mpimg\n",
    "img = mpimg.imread('data/xu/1.jpg') \n",
    "img1 = mpimg.imread('data/yu/1.jpg') \n",
    "plt.figure(figsize=(10,10))\n",
    "plt.subplot(1,2,1)\n",
    "plt.imshow(img)\n",
    "plt.axis('off') \n",
    "plt.subplot(1,2,2)\n",
    "plt.imshow(img1)\n",
    "plt.axis('off') \n",
    "plt.show()\n",
    "print(\"                许佳琪                                     虞书欣\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step4、生成数据读取器\n",
    "\n",
    "接着生成一个图像分类的reader，reader负责将dataset的数据进行预处理，接着以特定格式组织并输入给模型进行训练。\n",
    "\n",
    "当我们生成一个图像分类的reader时，需要指定输入图片的大小"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2021-03-08 16:50:17,521] [    INFO] - Dataset label map = {'许佳琪': 0, '虞书欣': 1}\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "data_reader = hub.reader.ImageClassificationReader(\n",
    "    image_width=module.get_expected_image_width(),\n",
    "    image_height=module.get_expected_image_height(),\n",
    "    images_mean=module.get_pretrained_images_mean(),\n",
    "    images_std=module.get_pretrained_images_std(),\n",
    "    dataset=dataset)"
   ]
  },
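  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "To make the reader's role more concrete, the cell below is a minimal sketch of the kind of preprocessing such a reader typically performs: scale pixel values to [0, 1], normalize each channel with a mean and standard deviation, and reorder the axes to channel-first. It is not PaddleHub's actual implementation. The function name `normalize_image`, the 224x224 dummy image, and the mean/std values (the commonly used ImageNet channel statistics) are assumptions for illustration; in this notebook the real values come from `module.get_pretrained_images_mean()` and `module.get_pretrained_images_std()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def normalize_image(image_hwc, mean, std):\n",
    "    # image_hwc: uint8 array of shape (height, width, 3)\n",
    "    scaled = image_hwc.astype('float32') / 255.0\n",
    "    mean = np.asarray(mean, dtype='float32')\n",
    "    std = np.asarray(std, dtype='float32')\n",
    "    normalized = (scaled - mean) / std  # broadcasts over the channel axis\n",
    "    return normalized.transpose(2, 0, 1)  # channel-first: (3, height, width)\n",
    "\n",
    "\n",
    "# Assumed values for illustration, not read from the module\n",
    "dummy_image = np.full((224, 224, 3), 128, dtype='uint8')\n",
    "chw = normalize_image(dummy_image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "print(chw.shape)"
   ]
  },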
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step5、配置策略\n",
    "在进行Finetune前，我们可以设置一些运行时的配置，例如如下代码中的配置，表示：\n",
    "\n",
    "* `use_cuda`：设置为False表示使用CPU进行训练。如果您本机支持GPU，且安装的是GPU版本的PaddlePaddle，我们建议您将这个选项设置为True；\n",
    "\n",
    "* `epoch`：迭代轮数；\n",
    "\n",
    "* `batch_size`：每次训练的时候，给模型输入的每批数据大小为32，模型训练时能够并行处理批数据，因此batch_size越大，训练的效率越高，但是同时带来了内存的负荷，过大的batch_size可能导致内存不足而无法训练，因此选择一个合适的batch_size是很重要的一步；\n",
    "\n",
    "* `log_interval`：每隔10 step打印一次训练日志；\n",
    "\n",
    "* `eval_interval`：每隔50 step在验证集上进行一次性能评估；\n",
    "\n",
    "* `checkpoint_dir`：将训练的参数和数据保存到cv_finetune_turtorial_demo目录中；\n",
    "\n",
    "* `strategy`：使用DefaultFinetuneStrategy策略进行finetune；\n",
    "\n",
    "更多运行配置，请查看[RunConfig](https://github.com/PaddlePaddle/PaddleHub/wiki/PaddleHub-API:-RunConfig)\n",
    "\n",
    "同时PaddleHub提供了许多优化策略，如`AdamWeightDecayStrategy`、`ULMFiTStrategy`、`DefaultFinetuneStrategy`等，详细信息参见[策略](https://github.com/PaddlePaddle/PaddleHub/wiki/PaddleHub-API:-Strategy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2021-03-08 16:50:17,527] [    INFO] - Checkpoint dir: cv_finetune_turtorial_demo\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "config = hub.RunConfig(\n",
    "    use_cuda=True,                              #是否使用GPU训练，默认为False；\n",
    "    num_epoch=3,                                #Fine-tune的轮数；\n",
    "    checkpoint_dir=\"cv_finetune_turtorial_demo\",#模型checkpoint保存路径, 若用户没有指定，程序会自动生成；\n",
    "    batch_size=3,                              #训练的批大小，如果使用GPU，请根据实际情况调整batch_size；\n",
    "    eval_interval=10,                           #模型评估的间隔，默认每100个step评估一次验证集；\n",
    "    strategy=hub.finetune.strategy.DefaultFinetuneStrategy())  #Fine-tune优化策略；"
   ]
  },
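  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "As a rough sanity check on these settings, the cell below estimates how many training steps a run takes. The training log later in this notebook counts steps as `step 10 / 30`. The sketch assumes one epoch is a single pass over the training list and that a final partial batch still counts as a step; PaddleHub's exact rounding is not confirmed here. The helper `estimate_total_steps` and the training-set size of 30 are hypothetical, chosen only because they are consistent with that log, not read from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def estimate_total_steps(num_samples, batch_size, num_epoch):\n",
    "    # Assumption: a final partial batch still counts as one step\n",
    "    steps_per_epoch = math.ceil(num_samples / batch_size)\n",
    "    return steps_per_epoch * num_epoch\n",
    "\n",
    "\n",
    "num_train_samples = 30  # hypothetical size, not read from train_list.txt\n",
    "total_steps = estimate_total_steps(num_train_samples, batch_size=3, num_epoch=3)\n",
    "print(total_steps)"
   ]
  },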
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step6、组建Finetune Task\n",
    "有了合适的预训练模型和准备要迁移的数据集后，我们开始组建一个Task。\n",
    "\n",
    "由于该数据设置是一个二分类的任务，而我们下载的分类module是在ImageNet数据集上训练的千分类模型，所以我们需要对模型进行简单的微调，把模型改造为一个二分类模型：\n",
    "\n",
    "1. 获取module的上下文环境，包括输入和输出的变量，以及Paddle Program；\n",
    "2. 从输出变量中找到特征图提取层feature_map；\n",
    "3. 在feature_map后面接入一个全连接层，生成Task；"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2021-03-08 16:50:17,591] [    INFO] - 267 pretrained paramaters loaded by PaddleHub\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "input_dict, output_dict, program = module.context(trainable=True)\n",
    "img = input_dict[\"image\"]\n",
    "feature_map = output_dict[\"feature_map\"]\n",
    "feed_list = [img.name]\n",
    "\n",
    "task = hub.ImageClassifierTask(\n",
    "    data_reader=data_reader,\n",
    "    feed_list=feed_list,\n",
    "    feature=feature_map,\n",
    "    num_classes=dataset.num_labels,\n",
    "    config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step5、开始Finetune\n",
    "\n",
    "我们选择`finetune_and_eval`接口来进行模型训练，这个接口在finetune的过程中，会周期性的进行模型效果的评估，以便我们了解整个训练过程的性能变化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[2021-03-08 16:50:18,758] [    INFO] - Strategy with slanted triangle learning rate, L2 regularization, \u001b[0m\n",
      "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:811: UserWarning: There are no operators in the program to be executed. If you pass Program manually, please use fluid.program_guard to ensure the current Program is being used.\n",
      "  warnings.warn(error_info)\n",
      "\u001b[32m[2021-03-08 16:50:21,194] [    INFO] - Try loading checkpoint from cv_finetune_turtorial_demo/ckpt.meta\u001b[0m\n",
      "\u001b[32m[2021-03-08 16:50:21,195] [    INFO] - PaddleHub model checkpoint not found, start from scratch...\u001b[0m\n",
      "\u001b[32m[2021-03-08 16:50:21,245] [    INFO] - PaddleHub finetune start\u001b[0m\n",
      "\u001b[36m[2021-03-08 16:50:24,430] [   TRAIN] - step 10 / 30: loss=0.72682 acc=0.60000 [step/sec: 3.63]\u001b[0m\n",
      "\u001b[32m[2021-03-08 16:50:24,431] [    INFO] - Evaluation on dev dataset start\u001b[0m\n",
      "share_vars_from is set, scope is ignored.\n",
      "\u001b[34m[2021-03-08 16:50:25,115] [    EVAL] - [dev dataset evaluation result] loss=0.56531 acc=0.50000 [step/sec: 23.67]\u001b[0m\n",
      "\u001b[34m[2021-03-08 16:50:25,116] [    EVAL] - best model saved to cv_finetune_turtorial_demo/best_model [best acc=0.50000]\u001b[0m\n",
      "\u001b[36m[2021-03-08 16:50:28,918] [   TRAIN] - step 20 / 30: loss=0.53640 acc=0.80000 [step/sec: 3.74]\u001b[0m\n",
      "\u001b[32m[2021-03-08 16:50:28,919] [    INFO] - Evaluation on dev dataset start\u001b[0m\n",
      "\u001b[34m[2021-03-08 16:50:29,086] [    EVAL] - [dev dataset evaluation result] loss=0.49044 acc=1.00000 [step/sec: 93.25]\u001b[0m\n",
      "\u001b[34m[2021-03-08 16:50:29,086] [    EVAL] - best model saved to cv_finetune_turtorial_demo/best_model [best acc=1.00000]\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "run_states = task.finetune_and_eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Step6、预测\n",
    "\n",
    "当Finetune完成后，我们使用模型来进行预测，先通过以下命令来获取测试的图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt \n",
    "import matplotlib.image as mpimg\n",
    "\n",
    "data = [\"data/xu/16.jpg\",\"data/yu/16.jpg\"]\n",
    "\n",
    "label_map = dataset.label_dict()\n",
    "index = 0\n",
    "run_states = task.predict(data=data)\n",
    "results = [run_state.run_results for run_state in run_states]\n",
    "\n",
    "for batch_result in results:\n",
    "    print(batch_result)\n",
    "    batch_result = np.argmax(batch_result, axis=2)[0]\n",
    "    print(batch_result)\n",
    "    for result in batch_result:\n",
    "        index += 1\n",
    "        result = label_map[result]\n",
    "        print(\"input %i is %s, and the predict result is %s\" %\n",
    "              (index, data[index - 1], result))\n",
    "\n",
    "img = mpimg.imread(data[0]) \n",
    "img1 = mpimg.imread(data[1]) \n",
    "plt.figure(figsize=(10,10))\n",
    "plt.subplot(1,2,1)\n",
    "plt.imshow(img)\n",
    "plt.axis('off') \n",
    "plt.subplot(1,2,2)\n",
    "plt.imshow(img1)\n",
    "plt.axis('off') \n",
    "plt.show()\n",
    "print(\"             input1 许佳琪                             input2 虞书欣\")"
   ]
  }
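  ,
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The indexing in the prediction loop above (`np.argmax(batch_result, axis=2)[0]`) suggests that each `batch_result` behaves like an array of shape `(1, batch_size, num_classes)`. The cell below is a small sketch of that step on made-up scores; the score values and the placeholder labels `class_a` and `class_b` are hypothetical and do not come from the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical scores: one batch holding two images, two classes each\n",
    "batch_result = np.array([[[0.9, 0.1], [0.2, 0.8]]])  # shape (1, 2, 2)\n",
    "label_map = {0: 'class_a', 1: 'class_b'}  # placeholder index-to-label map\n",
    "\n",
    "# Pick the highest-scoring class for every image, then drop the leading axis\n",
    "predicted_ids = np.argmax(batch_result, axis=2)[0]\n",
    "predicted_labels = [label_map[int(i)] for i in predicted_ids]\n",
    "print(predicted_labels)"
   ]
  }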
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.7.2 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
